import base64
import pickle
from typing import Any, Iterable, List, Optional, Tuple

from omagent_core.memories.ltms.ltm_base import LTMBase
from omagent_core.services.connectors.milvus import MilvusConnector
from omagent_core.utils.registry import registry
from pydantic import Field
from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
                      utility)


# Sentinel so ``pop`` can tell "no default given" apart from ``default=None``.
_MISSING = object()


def _key_filter(key_str: str) -> str:
    """Build a Milvus boolean expression matching records whose ``key`` equals *key_str*.

    Backslashes and double quotes are escaped so a key containing them can neither
    break the expression nor change its meaning.
    """
    escaped = key_str.replace("\\", "\\\\").replace('"', '\\"')
    return f'key == "{escaped}"'


@registry.register_component()
class VideoMilvusLTM(LTMBase):
    """Dict-like long-term memory stored in the Milvus collection ``storage_name``.

    Each record has three fields:
    - ``key``: VARCHAR primary key (max 256 chars) holding ``str(key)``.
    - ``value``: JSON field holding the stored value as-is.
    - ``embedding``: FLOAT_VECTOR of size ``dim``, indexed FLAT / COSINE.
    """

    milvus_ltm_client: MilvusConnector
    storage_name: str = Field(default="default")
    dim: int = Field(default=128)

    def model_post_init(self, __context: Any) -> None:
        pass

    def _create_collection(self) -> None:
        """Create the backing collection together with its vector index, if missing."""
        client = self.milvus_ltm_client._client
        if client.has_collection(self.storage_name):
            return

        key_field = FieldSchema(
            name="key", dtype=DataType.VARCHAR, is_primary=True, max_length=256
        )
        value_field = FieldSchema(
            name="value", dtype=DataType.JSON, description="Json value"
        )
        embedding_field = FieldSchema(
            name="embedding",
            dtype=DataType.FLOAT_VECTOR,
            description="Embedding vector",
            dim=self.dim,
        )
        schema = CollectionSchema(
            fields=[key_field, value_field, embedding_field],
            description="Key-Value storage with embeddings",
        )

        # Index every vector field. The index params are handed to
        # create_collection, so creation and indexing happen in one call.
        index_params = client.prepare_index_params()
        for field in schema.fields:
            if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
                index_params.add_index(
                    field_name=field.name,
                    index_name=field.name,
                    index_type="FLAT",
                    metric_type="COSINE",
                    params={"nlist": 128},
                )
        client.create_collection(
            self.storage_name, schema=schema, index_params=index_params
        )
        print(f"Created storage {self.storage_name} successfully")

    def __getitem__(self, key: Any) -> Any:
        """Return the value stored under ``str(key)``.

        Raises:
            KeyError: if no record has that key.
        """
        res = self.milvus_ltm_client._client.query(
            self.storage_name,
            filter=_key_filter(str(key)),
            output_fields=["value"],
        )
        if not res:
            raise KeyError(f"Key {key} not found")
        return res[0]["value"]

    def __setitem__(self, key: Any, value: Any) -> None:
        """Store ``value["value"]`` with vector ``value["embedding"]`` under ``str(key)``.

        Any existing record with the same key is replaced.

        Raises:
            ValueError: if ``value`` is not a dict with ``"value"`` and
                ``"embedding"`` keys, or if the embedding is None.
        """
        # Validate before touching Milvus so bad input never creates a collection.
        if not (isinstance(value, dict) and "value" in value and "embedding" in value):
            raise ValueError(
                "When setting an item, value must be a dictionary containing 'value' and 'embedding' keys."
            )
        actual_value = value["value"]
        embedding = value["embedding"]
        if embedding is None:
            raise ValueError("An embedding vector must be provided.")

        self._create_collection()
        client = self.milvus_ltm_client._client
        key_str = str(key)

        # Milvus insert does not de-duplicate primary keys, so remove the old
        # record first to keep exactly one row per key.
        if key_str in self:
            client.delete(self.storage_name, filter=_key_filter(key_str))

        client.insert(
            collection_name=self.storage_name,
            data=[{"key": key_str, "value": actual_value, "embedding": embedding}],
        )

    def __delitem__(self, key: Any) -> None:
        """Delete the record stored under ``str(key)``.

        Raises:
            KeyError: if no record has that key.
        """
        key_str = str(key)
        if key_str not in self:
            raise KeyError(f"Key {key} not found")
        # ``filter`` must be a keyword argument: the second positional parameter
        # of MilvusClient.delete is a primary-key list, not an expression.
        self.milvus_ltm_client._client.delete(
            self.storage_name, filter=_key_filter(key_str)
        )

    def __contains__(self, key: Any) -> bool:
        """Return True if a record is stored under ``str(key)``."""
        res = self.milvus_ltm_client._client.query(
            self.storage_name,
            filter=_key_filter(str(key)),
            output_fields=["key"],
        )
        return len(res) > 0

    def __len__(self) -> int:
        """Count records with a non-empty key, using a strongly consistent query."""
        results = self.milvus_ltm_client._client.query(
            self.storage_name,
            filter='key != ""',
            output_fields=["key"],
            consistency_level="Strong",
        )
        return len(results)

    def keys(self, limit: int = 10) -> Iterable[Any]:
        """Return a generator over at most ``limit`` stored keys."""
        res = self.milvus_ltm_client._client.query(
            self.storage_name, filter="", output_fields=["key"], limit=limit
        )
        return (item["key"] for item in res)

    def values(self) -> Iterable[Any]:
        """Yield every stored value exactly as held in the JSON ``value`` field.

        This matches what ``__getitem__`` and ``items`` return.
        """
        res = self.milvus_ltm_client._client.query(
            self.storage_name,
            filter='key != ""',
            output_fields=["value"],
            consistency_level="Strong",
        )
        for item in res:
            yield item["value"]

    def items(self) -> Iterable[Tuple[Any, Any]]:
        """Yield ``(key, value)`` for every record with a non-empty key."""
        res = self.milvus_ltm_client._client.query(
            self.storage_name, filter='key != ""', output_fields=["key", "value"]
        )
        for item in res:
            yield (item["key"], item["value"])

    def get(self, key: Any, default: Any = None) -> Any:
        """Return ``self[key]``, or ``default`` if the key is absent."""
        try:
            return self[key]
        except KeyError:
            return default

    def clear(self) -> None:
        """Delete every record whose key is non-empty."""
        self.milvus_ltm_client._client.delete(self.storage_name, filter='key != ""')

    def pop(self, key: Any, default: Any = _MISSING) -> Any:
        """Remove ``key`` and return its value.

        Like ``dict.pop``:
        - If the key is absent and ``default`` was given (including ``None``),
          ``default`` is returned.
        - If the key is absent and no default was given, KeyError is raised.
        """
        try:
            value = self[key]
            del self[key]
            return value
        except KeyError:
            if default is _MISSING:
                raise
            return default

    def update(self, other: Iterable[Tuple[Any, Any]]) -> None:
        """Store every ``(key, value)`` pair from ``other``.

        ``other`` may be an iterable of pairs or, as with ``dict.update``, any
        object with a ``keys()`` method (for example a dict). Each value must
        follow the ``__setitem__`` contract.
        """
        if hasattr(other, "keys"):
            pairs = ((k, other[k]) for k in other.keys())
        else:
            pairs = other
        for key, value in pairs:
            self[key] = value

    def get_by_vector(
        self,
        embedding: List[float],
        top_k: int = 10,
        threshold: float = 0.0,
        filter: str = "",
    ) -> List[Tuple[Any, Any]]:
        """Return up to ``top_k`` ``(key, value)`` pairs most similar to ``embedding``.

        Runs a COSINE range search with ``radius=threshold`` and
        ``range_filter=1``, so hits at or below ``threshold`` are excluded.
        ``filter`` is an optional Milvus boolean expression applied to the search.
        Similarity scores are not returned.
        """
        search_params = {
            "metric_type": "COSINE",
            "params": {"nprobe": 10, "range_filter": 1, "radius": threshold},
        }
        results = self.milvus_ltm_client._client.search(
            self.storage_name,
            data=[embedding],
            anns_field="embedding",
            search_params=search_params,
            limit=top_k,
            output_fields=["key", "value"],
            consistency_level="Strong",
            filter=filter,
        )
        # One query vector was sent, so results[0] holds all of its hits.
        return [
            (hit.get("entity").get("key"), hit.get("entity").get("value"))
            for hit in results[0]
        ]
